package com.shuohe.util.fitting;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.apache.commons.jexl2.Expression;
import org.apache.commons.jexl2.JexlContext;
import org.apache.commons.jexl2.JexlEngine;
import org.apache.commons.jexl2.MapContext;
import org.apache.commons.math3.fitting.PolynomialCurveFitter;
import org.apache.commons.math3.fitting.WeightedObservedPoints;
/**
 * Polynomial curve-fitting helpers built on Apache Commons Math.
 *
 * <p>Every public method follows the same workflow:
 * <ol>
 *   <li>Collect the x and y samples (two arrays) into a {@link WeightedObservedPoints} instance,
 *       one {@code add(x[i], y[i])} call per sample.</li>
 *   <li>Create a {@link PolynomialCurveFitter} for the requested polynomial degree. Choose the
 *       degree with care: a higher degree is not automatically better and can make the fitting
 *       error large.</li>
 *   <li>Call {@code fit(points.toList())}. The returned {@code double[]} holds the coefficients in
 *       ascending order of power: constant term, linear term, quadratic term, and so on.</li>
 * </ol>
 */
public class FittingUtils {

    /** Opening of the generated JavaScript-style function; its parameter {@code a} is the x value. */
    private static final String BEFORE_METHOD = "function calculationValue(a){\r\n";

    /** Closing brace of the generated function. */
    private static final String AFTER_METHOD = "}";

    /**
     * Small demo: fits a quadratic through three sample points and prints the generated function.
     *
     * @param args ignored
     * @throws Exception if fitting fails
     */
    public static void main(String[] args) throws Exception {
        Double[] inputDataX = {1.0, 2.0, 3.0};
        Double[] inputDataY = {12.0, 11.0, 50.0};
        int degree = 2;
        getFitting(inputDataX, inputDataY, degree);
    }

    /**
     * Fits a polynomial of degree {@code n} to the sample points and returns it as the source of a
     * JavaScript-style function {@code calculationValue(a)} whose body returns the fitted
     * polynomial evaluated at {@code a}.
     *
     * <p>All {@code n + 1} coefficients are emitted, so degrees above 2 are supported. The
     * generated source is also printed to standard output.
     *
     * @param arrayX x coordinates of the samples; pairs whose x or y is {@code null} are skipped
     * @param arrayY y coordinates of the samples; must have the same length as {@code arrayX}
     * @param n      degree of the fitting polynomial
     * @return the generated function source
     * @throws Exception if {@code arrayX} and {@code arrayY} differ in length
     */
    public static String getFitting(Double[] arrayX, Double[] arrayY, int n) throws Exception {
        double[] coefficients = fitCoefficients(arrayX, arrayY, n);
        String function =
                BEFORE_METHOD + " return " + toExpression(coefficients) + "\r\n" + AFTER_METHOD;
        System.out.println(function);
        return function;
    }

    /**
     * Fits a polynomial of degree {@code n} to the sample points and evaluates it at {@code d}
     * (given X, compute Y from the fitted formula).
     *
     * <p>The polynomial is evaluated directly from the fitted {@code double} coefficients rather
     * than by re-parsing a printed expression, so no precision is lost to string formatting. The
     * formula text and the computed value are printed to standard output.
     *
     * @param arrayX x coordinates of the samples; pairs whose x or y is {@code null} are skipped
     * @param arrayY y coordinates of the samples; must have the same length as {@code arrayX}
     * @param n      degree of the fitting polynomial
     * @param d      the x value at which to evaluate the fitted polynomial; must not be null
     * @return the fitted polynomial's value at {@code d}
     * @throws IllegalArgumentException if {@code d} is null
     * @throws Exception if {@code arrayX} and {@code arrayY} differ in length
     * @author xuej 20190619
     */
    public static Double getFormula(Double[] arrayX, Double[] arrayY, int n, Double d)
            throws Exception {
        if (d == null) {
            // Reject explicitly instead of depending on how an evaluator treats a null variable.
            throw new IllegalArgumentException("d must not be null");
        }
        double[] coefficients = fitCoefficients(arrayX, arrayY, n);
        System.out.println(toExpression(coefficients));
        Double value = evaluate(coefficients, d);
        System.out.println(value);
        return value;
    }

    /**
     * Validates the samples, fits a degree-{@code n} polynomial and returns its coefficients,
     * constant term first.
     */
    private static double[] fitCoefficients(Double[] arrayX, Double[] arrayY, int n)
            throws Exception {
        if (arrayX.length != arrayY.length) {
            throw new Exception("x坐标数组和y坐标数组长度不相等");
        }
        WeightedObservedPoints points = new WeightedObservedPoints();
        for (int i = 0; i < arrayX.length; i++) {
            // Skip incomplete samples instead of failing with a NullPointerException on unboxing.
            if (arrayX[i] != null && arrayY[i] != null) {
                points.add(arrayX[i], arrayY[i]);
            }
        }
        return PolynomialCurveFitter.create(n).fit(points.toList());
    }

    /**
     * Renders coefficients (constant term first) as an expression in the variable {@code a},
     * e.g. {@code 1.0+2.0*a-3.0*a*a}; every coefficient is emitted regardless of degree.
     */
    private static String toExpression(double[] coefficients) {
        StringBuilder expression = new StringBuilder();
        for (int power = 0; power < coefficients.length; power++) {
            String coefficient = String.valueOf(coefficients[power]);
            // Negative values already carry their '-' sign; anything else (including 0.0) needs an
            // explicit '+' so adjacent terms do not run together, e.g. "5.00.0*a".
            if (power > 0 && !coefficient.startsWith("-")) {
                expression.append('+');
            }
            expression.append(coefficient);
            for (int k = 0; k < power; k++) {
                expression.append("*a");
            }
        }
        return expression.toString();
    }

    /**
     * Evaluates the polynomial with the given coefficients (constant term first) at {@code x}
     * using Horner's scheme.
     */
    private static double evaluate(double[] coefficients, double x) {
        double y = 0.0;
        for (int power = coefficients.length - 1; power >= 0; power--) {
            y = y * x + coefficients[power];
        }
        return y;
    }
}

